{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lxI1JxQ5-P_V"
      },
      "source": [
        "# getting ready\n",
        "安装相关依赖"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "step1 下载onnx模型与测试视频\n"
      ],
      "metadata": {
        "id": "srI36yYZ_V2g"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!gdown --fuzzy https://github.com/BabitMF/bmf/releases/download/files/models.tar.gz\n",
        "!gdown --fuzzy https://github.com/BabitMF/bmf/releases/download/files/files.tar.gz\n",
        "!tar xzvf models.tar.gz\n",
        "!tar xzvf files.tar.gz"
      ],
      "metadata": {
        "id": "ZbGVuigI_O5r"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "step2 安装BMF、onnxruntime-module"
      ],
      "metadata": {
        "id": "V5R2gGuo_k7y"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RpvU1uEd-jp8"
      },
      "outputs": [],
      "source": [
        "%pip install BabitMF\n",
        "%pip install onnxruntime\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "step3 获取BMF源码，找到demo模块，测试modules和model文件可以正常使用"
      ],
      "metadata": {
        "id": "cTsnAaR57zwd"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!git clone https://github.com/BabitMF/bmf.git"
      ],
      "metadata": {
        "id": "YTZynL_DcR8E"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!cp bmf/bmf/demo/aesthetic_assessment/*.py ."
      ],
      "metadata": {
        "id": "9xpQzELLcwuZ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Smoke test: the demo modules import cleanly and the ONNX model loads.\n",
        "import os.path as osp\n",
        "import sys\n",
        "\n",
        "import bmf\n",
        "import onnxruntime as ort\n",
        "\n",
        "from module_utils import SyncModule\n",
        "import aesmod_module\n",
        "\n",
        "# `__file__` is not defined inside a notebook, so the model directory is\n",
        "# resolved against the current working directory.\n",
        "model_dir = osp.join(osp.abspath(osp.curdir), 'models')\n",
        "aesmod_ort_model_path = osp.realpath(osp.join(model_dir, 'aes_transonnx_update3.onnx'))\n",
        "print(aesmod_ort_model_path)\n",
        "ort_session = ort.InferenceSession(aesmod_ort_model_path)"
      ],
      "metadata": {
        "id": "UJaxbai9NwsJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JLmIr9To8Wui"
      },
      "source": [
        "# source code\n",
        "\n",
        "## aesmod_module.py\n",
        "\n",
        "\n",
        "*   func get_logger()\n",
        "*   func flex_resize_aesv2()\n",
        "*   class Aesmod\n",
        "*   class BMFAesmod\n",
        "\n",
        "\n",
        "## module_utils.py\n",
        "\n",
        "\n",
        "*   class SyncModule\n",
        "\n",
        "\n",
        "## main.py\n",
        "Main program that calls the BMF API and visualizes the output.\n",
        "*   func segment_decode_ticks()\n",
        "*   func get_duration()\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#!/usr/bin/env python\n",
        "# -*- coding: utf-8 -*-\n",
        "### aesmod_module.py\n",
        "\n",
        "from module_utils import SyncModule\n",
        "import os\n",
        "import time\n",
        "import json\n",
        "import pdb\n",
        "import os.path as osp\n",
        "import numpy as np\n",
        "\n",
        "os.environ[\"OMP_NUM_THREADS\"] = \"8\"\n",
        "import onnxruntime as ort\n",
        "import torch\n",
        "import logging\n",
        "import cv2\n",
        "\n",
        "\n",
        "def get_logger():\n",
        "    return logging.getLogger(\"main\")\n",
        "\n",
        "\n",
        "LOGGER = get_logger()\n",
        "\n",
        "\n",
        "def flex_resize_aesv2(img, desired_size=[448, 672], pad_color=[0, 0, 0]):\n",
        "    old_h, old_w = img.shape[:2]  # old_size is in (height, width) format\n",
        "    if desired_size[0] >= desired_size[1]:\n",
        "        if old_h < old_w:  # rotate the honrizontal video\n",
        "            img = np.rot90(img, k=1, axes=(1, 0))\n",
        "    else:\n",
        "        if old_h > old_w:  # rotate the vertical video\n",
        "            img = np.rot90(img, k=1, axes=(1, 0))\n",
        "    old_h, old_w = img.shape[:2]\n",
        "\n",
        "    if old_w / old_h > (desired_size[1] / desired_size[0]):\n",
        "        ratio = desired_size[0] / old_h\n",
        "    else:\n",
        "        ratio = desired_size[1] / old_w\n",
        "    img = cv2.resize(img, None, fx=ratio, fy=ratio)\n",
        "    h, w, _ = img.shape\n",
        "    h_crop = (h - desired_size[0]) // 2\n",
        "    w_crop = (w - desired_size[1]) // 2\n",
        "    img = img[h_crop:h_crop + desired_size[0],\n",
        "              w_crop:w_crop + desired_size[1], :]\n",
        "    return img\n",
        "\n",
        "\n",
        "class Aesmod:\n",
        "\n",
        "    def __init__(self, model_path, model_version, result_path):\n",
        "        self._frm_idx = 0\n",
        "        self._frm_scores = []\n",
        "        self._model_version = model_version\n",
        "        self._result_path = result_path\n",
        "\n",
        "        # model_dir = osp.join(osp.abspath(osp.dirname(\"__file__\")), \"models\")\n",
        "        # aesmod_ort_model_path = osp.realpath(\n",
        "        #    osp.join(model_dir, \"aes_transonnx_update3.onnx\")\n",
        "        # )\n",
        "        self.use_gpu = False\n",
        "        aesmod_ort_model_path = model_path\n",
        "        print(aesmod_ort_model_path)\n",
        "        LOGGER.info(\"loading aesthetic ort inference session\")\n",
        "        self.ort_session = ort.InferenceSession(aesmod_ort_model_path)\n",
        "\n",
        "        self.resize_reso = [672, 448]\n",
        "\n",
        "    def preprocess(self, frame):\n",
        "        frame = flex_resize_aesv2(frame)\n",
        "        # print('using flex_resize_aesv2', frame.shape)\n",
        "        frame = (frame.astype(np.float32) / 255.0 -\n",
        "                 np.array([0.485, 0.456, 0.406], dtype=\"float32\")) / (np.array(\n",
        "                     [0.229, 0.224, 0.225], dtype=\"float32\"))\n",
        "        frame = np.transpose(frame, (2, 0, 1))\n",
        "        frame = np.expand_dims(frame, 0)\n",
        "        return frame\n",
        "\n",
        "    @staticmethod\n",
        "    def tensor_to_list(tensor):\n",
        "        if tensor.requires_grad:\n",
        "            return tensor.detach().cpu().flatten().tolist()\n",
        "        else:\n",
        "            return tensor.cpu().flatten().tolist()\n",
        "\n",
        "    @staticmethod\n",
        "    def score_pred_mapping(raw_scores, raw_min=2.60, raw_max=7.42):\n",
        "        pred_score = np.clip(\n",
        "            np.sum([x * (i + 1) for i, x in enumerate(raw_scores)]), raw_min,\n",
        "            raw_max)\n",
        "        pred_score = np.sqrt((pred_score - raw_min) / (raw_max - raw_min)) * 100\n",
        "        return float(np.clip(pred_score, 0, 100.0))\n",
        "\n",
        "    def process(self, frames):\n",
        "        frames = [\n",
        "            frame\n",
        "            if frame.flags[\"C_CONTIGUOUS\"] else np.ascontiguousarray(frame)\n",
        "            for frame in frames\n",
        "        ]\n",
        "        frame = self.preprocess(frames[0])\n",
        "        print(\"after preprocess shape\", frame.shape)\n",
        "        if not frame.flags[\"C_CONTIGUOUS\"]:\n",
        "            frame = np.ascontiguousarray(frame, dtype=np.float32)\n",
        "\n",
        "        t1 = time.time()\n",
        "        if self.use_gpu:\n",
        "            with torch.no_grad():\n",
        "                input_batch = torch.from_numpy(frame).contiguous().cuda()\n",
        "                preds, _ = self.trt_model(input_batch)\n",
        "                raw_score = self.tensor_to_list(preds)\n",
        "        else:\n",
        "\n",
        "            raw_score = self.ort_session.run(None, {\"input\": frame})\n",
        "            raw_score = raw_score[0][0]\n",
        "        score = self.score_pred_mapping(raw_score)\n",
        "        self._frm_scores.append(score)\n",
        "        self._frm_idx += 1\n",
        "        t2 = time.time()\n",
        "        LOGGER.info(f\"[Aesmod] inference time: {(t2 - t1) * 1000:0.1f} ms\")\n",
        "        return frames[0]\n",
        "\n",
        "    def clean(self):\n",
        "        nr_score = round(np.mean(self._frm_scores), 2)\n",
        "        results = {\n",
        "            \"aesthetic\": nr_score,\n",
        "            \"aesthetic_version\": self._model_version\n",
        "        }\n",
        "        LOGGER.info(f\"overall prediction {json.dumps(results)}\")\n",
        "        with open(self._result_path, \"w\") as outfile:\n",
        "            json.dump(results, outfile, indent=4, ensure_ascii=False)\n",
        "\n",
        "\n",
        "class BMFAesmod(SyncModule):\n",
        "    \"\"\"SyncModule wrapper that runs `Aesmod` on rgb24 frames.\"\"\"\n",
        "\n",
        "    def __init__(self, node=None, option=None):\n",
        "        \"\"\"Create the scorer from `option` and initialise the SyncModule base.\n",
        "\n",
        "        Recognised option keys: 'result_path', 'model_version' and\n",
        "        'model_path'; each falls back to the default used below.\n",
        "        \"\"\"\n",
        "        # `option` defaults to None in the signature; treat that as an empty\n",
        "        # mapping so the lookups below do not fail on `None.get`.\n",
        "        option = option or {}\n",
        "        result_path = option.get(\"result_path\", 0)\n",
        "        model_version = option.get(\"model_version\", \"v1.0\")\n",
        "        model_path = option.get(\"model_path\",\n",
        "                                \"./models/aes_transonnx_update3.onnx\")\n",
        "        self._nrp = Aesmod(model_path, model_version, result_path)\n",
        "        SyncModule.__init__(self,\n",
        "                            node,\n",
        "                            nb_in=1,\n",
        "                            in_fmt=\"rgb24\",\n",
        "                            out_fmt=\"rgb24\")\n",
        "\n",
        "    def core_process(self, frames):\n",
        "        \"\"\"Forward the frames to `Aesmod.process` and return its result.\"\"\"\n",
        "        return self._nrp.process(frames)\n",
        "\n",
        "    def clean(self):\n",
        "        \"\"\"Let the scorer write its summary.\"\"\"\n",
        "        self._nrp.clean()"
      ],
      "metadata": {
        "id": "gSaDEmro-yBJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!cat module_utils.py"
      ],
      "metadata": {
        "id": "qdxDsiSvz5JU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3An0nCEm7zXD"
      },
      "outputs": [],
      "source": [
        "#!/usr/bin/env python3\n",
        "# -*- coding: utf-8 -*-\n",
        "\n",
        "import bmf\n",
        "import cv2, os, sys\n",
        "\n",
        "def get_duration(video_path):\n",
        "    capture = cv2.VideoCapture(video_path)\n",
        "    fps = capture.get(cv2.CAP_PROP_FPS)      # OpenCV2 version 2 used \"CV_CAP_PROP_FPS\"\n",
        "    frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))\n",
        "    duration = frame_count / fps\n",
        "    capture.release()\n",
        "    return duration\n",
        "\n",
        "def segment_decode_ticks(video_path, seg_dur=4.0, lv1_dur_thres=24.0, max_dur=1000):\n",
        "    '''\n",
        "        bmf module new decode duration ticks\n",
        "        - 0 < Duration <= 24s, 抽帧间隔r=1, 抽帧0~24帧\n",
        "        - 24s < Duration <= 600s 分片抽取, 抽帧间隔r=1, 抽帧24帧\n",
        "            - 6个4s切片, 共计6x4=24帧\n",
        "        - duration > 600s, 分8片抽帧r=1, 抽帧数量32帧\n",
        "            - (600, inf), 8个4s切片, 共计8x4=32帧\n",
        "        最大解码长度 max_dur: 1000s\n",
        "    '''\n",
        "    duration = get_duration(video_path)\n",
        "    duration_ticks = []\n",
        "    if duration < lv1_dur_thres:\n",
        "        return dict()\n",
        "    elif duration <= 600:  # medium duration\n",
        "        seg_num = 6\n",
        "        seg_intev = (duration - seg_num * seg_dur) / (seg_num - 1)\n",
        "        if seg_intev < 0.5:\n",
        "            duration_ticks.extend([0, duration])\n",
        "        else:\n",
        "            for s_i in range(seg_num):\n",
        "                seg_init = s_i * (seg_dur + seg_intev)\n",
        "                seg_end = seg_init + seg_dur\n",
        "                duration_ticks.extend([round(seg_init, 3), round(seg_end, 3)])\n",
        "    else:  # long duration\n",
        "        seg_num = 8\n",
        "        seg_intev = (min(duration, max_dur) - seg_num * seg_dur) / (seg_num - 1)\n",
        "        for s_i in range(seg_num):\n",
        "            seg_init = s_i * (seg_dur + seg_intev)\n",
        "            seg_end = seg_init + seg_dur\n",
        "            duration_ticks.extend([round(seg_init, 3), round(seg_end, 3)])\n",
        "    return {'durations': duration_ticks}\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    # Input clip and the JSON file the module writes its score to.\n",
        "    input_path = 'files/bbb_360_20s.mp4'\n",
        "    outp_path = 'res2.json'\n",
        "\n",
        "    option = {'result_path': outp_path}\n",
        "    print('option', option)\n",
        "\n",
        "    # Extract one frame per second; merge in segment ticks when there are any.\n",
        "    decode_params = {\n",
        "        'input_path': input_path,\n",
        "        'video_params': {'extract_frames': {'fps': 1}},\n",
        "        **segment_decode_ticks(input_path),\n",
        "    }\n",
        "    print('decode_params', decode_params)\n",
        "\n",
        "    # The entry names a class in `__main__`, i.e. BMFAesmod as defined in\n",
        "    # this notebook; the module path resolves to the working directory.\n",
        "    py_module_path = os.path.abspath(os.path.dirname(os.path.dirname('__file__')))\n",
        "    py_entry = '__main__.BMFAesmod'\n",
        "    print(py_module_path, py_entry)\n",
        "\n",
        "    decoded = bmf.graph().decode(decode_params)\n",
        "    scored = decoded['video'].module('aesmod_module', option,\n",
        "                                     py_module_path, py_entry)\n",
        "    scored.run()\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "!cat res2.json"
      ],
      "metadata": {
        "id": "IgzO4K01jdjU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "res2.json 中的 aesthetic 字段即视频的美学评分。评分区间映射至 [0, 100]，分数越高代表画面的美学水平越高；其中在 AVA 数据集上好坏分类的阈值 5 分映射为 70.5 分。\n",
        "在对画面的分析中，美学是一个重要的维度。对于客观质量相似的画面，美学维度的评价能够进一步丰富对人类感知的拟合。对于图像美感的评价综合了色彩相关（亮度/饱和度/色彩丰富度...）、构图相关（三分线构图/对称性/前景背景对比）、语义相关（主题是否明确）、画质相关（纹理是否丰富清晰）等多个维度；同时，除了摄影经验之外，图像的审美质量还受到情感和个人偏好的影响，例如对不同内容类型或风格的偏好。综上，实际的预测精度与主观感受仍然相差较大。"
      ],
      "metadata": {
        "id": "mKcZqOW40GQp"
      }
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
